package it.uniroma2.dsk.full;

import it.uniroma2.dsk.BoundedAbstractDS;
import it.uniroma2.dtk.op.IdealOperation;
import it.uniroma2.util.math.ArrayMath;
import it.uniroma2.util.math.MatrixUtils;

import java.util.HashMap;

/**
 * {@link BoundedAbstractDS} implementation that memoizes the intermediate
 * vectors produced by {@code dRecursive(s, i, j)} and
 * {@code deltaRecursive(s, i, j)} in two hash maps, so that each
 * {@code (i, j)} pair is computed at most once between two calls to
 * {@link #initializeStore(Object[])}.
 *
 * <p>Both caches are keyed by {@code i*p+j}, where {@code p} is inherited from
 * the superclass. Cached arrays are returned as-is (not copied).
 *
 * @param <T> type of the elements of the sequences handled by this class
 */
public class BoundedFullDS<T> extends BoundedAbstractDS<T> {

	/** Results of {@code dRecursive(s, i, j)} already computed, keyed by {@code i*p+j}. */
	protected HashMap<Integer, double[]> djStored;

	/** Results of {@code deltaRecursive(s, i, j)} already computed, keyed by {@code i*p+j}. */
	protected HashMap<Integer, double[]> deltajStored;

	/**
	 * Forwards all arguments unchanged to the superclass constructor.
	 *
	 * @param opImplementationClass class handed to the superclass as operation implementation
	 * @throws Exception if the superclass constructor fails
	 */
	public BoundedFullDS(int randomOffset, int vectorsSize, double lambda, int bound, Class<?> opImplementationClass) throws Exception {
		super(randomOffset, vectorsSize, lambda, bound, opImplementationClass);
	}

	/**
	 * Forwards all arguments unchanged to the superclass constructor.
	 *
	 * @param opImplementation operation instance handed to the superclass
	 * @throws Exception if the superclass constructor fails
	 */
	public BoundedFullDS(int randomOffset, int vectorsSize, double lambda, int bound, IdealOperation opImplementation) throws Exception {
		super(randomOffset, vectorsSize, lambda, bound, opImplementation);
	}

	/**
	 * Replaces both caches with fresh, empty maps whose initial capacity is
	 * {@code p * s.length}.
	 *
	 * @param s the sequence about to be processed
	 */
	@Override
	protected void initializeStore(T[] s) {
		final int initialCapacity = p * s.length;
		djStored = new HashMap<Integer, double[]>(initialCapacity);
		deltajStored = new HashMap<Integer, double[]>(initialCapacity);
	}

	/**
	 * Sums {@code dRecursive(s, i, j)} over {@code j = 1..p}, starting from a
	 * uniform vector of {@code vectorSize} zeros.
	 *
	 * @param s the sequence
	 * @param i index into {@code s}
	 * @return the accumulated vector
	 * @throws Exception propagated from the underlying computations
	 */
	@Override
	protected double[] dRecursive(T[] s, int i) throws Exception {
		double[] total = MatrixUtils.uniformVector(vectorSize, 0);
		int span = 1;
		while (span <= p) {
			total = ArrayMath.sum(total, dRecursive(s, i, span));
			span++;
		}
		return total;
	}

	/**
	 * Memoized computation of the {@code (i, j)} term.
	 *
	 * <ul>
	 * <li>{@code i + j > s.length}: a zero vector;</li>
	 * <li>{@code j == 1}: the random vector for {@code s[i]} scaled by {@code lambda};</li>
	 * <li>otherwise: {@code op} applied to the random vector for {@code s[i]}
	 *     and {@code deltaRecursive(s, i, j)}.</li>
	 * </ul>
	 *
	 * @param s the sequence
	 * @param i index into {@code s}
	 * @param j second index of the term; callers in this class pass values {@code >= 1}
	 * @return the cached or newly computed vector (the stored array itself, not a copy)
	 * @throws Exception propagated from the underlying computations
	 */
	protected double[] dRecursive(T[] s, int i, int j) throws Exception {
		final Integer key = i * p + j;
		if (djStored.containsKey(key)) {
			return djStored.get(key);
		}

		final double[] computed;
		if (i + j > s.length) {
			// The (i, j) pair runs past the end of the sequence.
			computed = MatrixUtils.uniformVector(vectorSize, 0);
		} else {
			final double[] elementVector = vectorProvider.generateRandomVector(getObjectCode(s[i]));
			computed = (j == 1)
					? ArrayMath.scalardot(lambda, elementVector)
					: op(elementVector, deltaRecursive(s, i, j));
		}

		djStored.put(key, computed);
		return computed;
	}

	/**
	 * Memoized computation of the auxiliary {@code (i, j)} term used by
	 * {@code dRecursive(s, i, j)}.
	 *
	 * <p>When {@code j != 1} and {@code i + j <= s.length} the result is
	 * {@code lambda * (deltaRecursive(s, i+1, j) + dRecursive(s, i+1, j-1))};
	 * in every other case it is a zero vector.
	 *
	 * @param s the sequence
	 * @param i index into {@code s}
	 * @param j second index of the term
	 * @return the cached or newly computed vector (the stored array itself, not a copy)
	 * @throws Exception propagated from the underlying computations
	 */
	protected double[] deltaRecursive(T[] s, int i, int j) throws Exception {
		final Integer key = i * p + j;
		if (deltajStored.containsKey(key)) {
			return deltajStored.get(key);
		}

		final double[] computed;
		if (j != 1 && i + j <= s.length) {
			// Evaluate the delta term first, then the d term, before combining them.
			final double[] nextDelta = deltaRecursive(s, i + 1, j);
			final double[] nextD = dRecursive(s, i + 1, j - 1);
			computed = ArrayMath.scalardot(lambda, ArrayMath.sum(nextDelta, nextD));
		} else {
			computed = MatrixUtils.uniformVector(vectorSize, 0);
		}

		deltajStored.put(key, computed);
		return computed;
	}

}
